class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer updating entity queries and subject/object queries.

    The layer runs two weight-sharing streams in parallel: the main stream
    and a ``*_cdecoder`` stream (labelled "one2many" in the original
    comments).  Both streams go through exactly the same sub-modules; they
    differ only in their input states and reference points.

    Per stream the processing is:

    * Entity: self-attention (main stream only) -> deformable
      cross-attention over ``src`` -> FFN.
    * Relation: joint self-attention over the concatenated subject and
      object queries ("CSA", main stream only) -> per-role deformable
      cross-attention over ``src`` ("DVA") -> per-role attention over the
      updated entity queries ("DEA") -> per-role FFN.

    Every sub-block is residual and followed by LayerNorm (post-norm).
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # NOTE: attribute names and registration order are part of the
        # checkpoint / optimizer-state layout; do not rename or reorder.

        #* entity
        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn (the activation module is shared by all three FFNs)
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        #? CSA: joint self-attention over [subject; object] queries
        self.self_attn_so = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2_so = nn.Dropout(dropout)
        self.norm2_so = nn.LayerNorm(d_model)

        #? SUB DVA: subject queries -> src (deformable)
        self.cross_attn_sub = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1_sub = nn.Dropout(dropout)
        self.norm1_sub = nn.LayerNorm(d_model)

        #? SUB DEA: subject queries -> entity queries
        self.cross_sub_entity = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2_sub = nn.Dropout(dropout)
        self.norm2_sub = nn.LayerNorm(d_model)

        #? OBJ DVA: object queries -> src (deformable)
        self.cross_attn_obj = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1_obj = nn.Dropout(dropout)
        self.norm1_obj = nn.LayerNorm(d_model)

        #? OBJ DEA: object queries -> entity queries
        self.cross_obj_entity = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2_obj = nn.Dropout(dropout)
        self.norm2_obj = nn.LayerNorm(d_model)

        #? ffn (one per role)
        self.linear1_sub = nn.Linear(d_model, d_ffn)
        self.dropout3_sub = nn.Dropout(dropout)
        self.linear2_sub = nn.Linear(d_ffn, d_model)
        self.dropout4_sub = nn.Dropout(dropout)
        self.norm3_sub = nn.LayerNorm(d_model)

        self.linear1_obj = nn.Linear(d_model, d_ffn)
        self.dropout3_obj = nn.Dropout(dropout)
        self.linear2_obj = nn.Linear(d_ffn, d_model)
        self.dropout4_obj = nn.Dropout(dropout)
        self.norm3_obj = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Residual feed-forward block + LayerNorm for the entity queries."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_ffn_sub(self, tgt):
        """Residual feed-forward block + LayerNorm for the subject queries."""
        tgt2 = self.linear2_sub(self.dropout3_sub(self.activation(self.linear1_sub(tgt))))
        tgt = tgt + self.dropout4_sub(tgt2)
        tgt = self.norm3_sub(tgt)
        return tgt

    def forward_ffn_obj(self, tgt):
        """Residual feed-forward block + LayerNorm for the object queries."""
        tgt2 = self.linear2_obj(self.dropout3_obj(self.activation(self.linear1_obj(tgt))))
        tgt = tgt + self.dropout4_obj(tgt2)
        tgt = self.norm3_obj(tgt)
        return tgt

    def _deform_attn_block(self, attn, dropout, norm, tgt, pos, reference_points, memory):
        """Deformable cross-attention over ``memory`` + residual + LayerNorm.

        ``memory`` is the tuple ``(src, src_spatial_shapes, level_start_index,
        src_padding_mask)`` forwarded verbatim to ``attn``.
        """
        tgt2 = attn(self.with_pos_embed(tgt, pos), reference_points, *memory)
        return norm(tgt + dropout(tgt2))

    def _entity_attn_block(self, attn, dropout, norm, ffn, tgt, pos, entity, attn_mask):
        """Attend from role queries ``tgt`` to ``entity`` queries, then run ``ffn``.

        Inputs are batch-first; they are transposed because the
        ``nn.MultiheadAttention`` modules are built with the default
        sequence-first layout.
        """
        tgt2 = attn(
            query=self.with_pos_embed(tgt, pos).transpose(0, 1),
            key=entity.transpose(0, 1),
            value=entity.transpose(0, 1),
            attn_mask=attn_mask,
        )[0].transpose(0, 1)
        return ffn(norm(tgt + dropout(tgt2)))

    # When called inside a CUDA autocast region, custom_fwd casts floating
    # point CUDA tensor inputs to float32 and runs the body with autocast
    # disabled.
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(
        self,
        tgt,
        tgt_cdecoder,
        query_pos,
        reference_points,
        reference_points_cdecoder,
        src,
        src_spatial_shapes,
        level_start_index,
        src_padding_mask=None,
        self_attn_mask=None,
        lid = None,

        #*===============================
        tgt_sub = None,
        tgt_obj = None,
        tgt_sub_cdecoder = None,
        tgt_obj_cdecoder = None,
        query_pos_sub = None,
        query_pos_obj = None,
        reference_points_input_sub = None,
        reference_points_input_obj = None,
        reference_points_input_sub_cdecoder = None,
        reference_points_input_obj_cdecoder = None,

        so_embed = None,
        attn_mask_tri = None
    ):
        """Run one decoder layer on both streams.

        Tensors are batch-first; the original inline comments give
        ``[bs, n, 256]`` for query states (TODO confirm against the caller).

        Args:
            tgt: Entity query states of the main stream.
            tgt_cdecoder: Entity query states of the cdecoder stream.
                Ignored when ``lid == 0``.
            query_pos: Positional embedding added to entity queries, or None.
            reference_points: Reference points for the main entity stream.
            reference_points_cdecoder: Reference points for the cdecoder
                entity stream.
            src, src_spatial_shapes, level_start_index, src_padding_mask:
                Memory arguments forwarded unchanged to every ``MSDeformAttn``.
            self_attn_mask: ``attn_mask`` for the entity self-attention.
            lid: Layer index.  When it equals 0 every ``*_cdecoder`` state is
                replaced by the matching main-stream state taken right after
                its self-attention block.
            tgt_sub, tgt_obj: Subject / object query states (required despite
                the ``None`` default).
            tgt_sub_cdecoder, tgt_obj_cdecoder: cdecoder-stream counterparts.
                Ignored when ``lid == 0``.
            query_pos_sub, query_pos_obj: Positional embeddings for the
                subject / object queries, or None.
            reference_points_input_sub, reference_points_input_obj,
            reference_points_input_sub_cdecoder,
            reference_points_input_obj_cdecoder: Reference points for the
                four DVA calls.
            so_embed: Indexable holding the subject (``[0]``) and object
                (``[1]``) embeddings added to the CSA queries/keys, or None
                to add nothing.
            attn_mask_tri: Optional pair ``(attn_mask_csa, attn_mask_dea)``.

        Returns:
            Tuple ``(tgt, tgt_cdecoder, tgt_sub, tgt_obj, tgt_sub_cdecoder,
            tgt_obj_cdecoder)`` of updated query states.
        """
        memory = (src, src_spatial_shapes, level_start_index, src_padding_mask)

        #*==================Entity================================
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q.transpose(0, 1),
            k.transpose(0, 1),
            tgt.transpose(0, 1),
            attn_mask=self_attn_mask,
        )[0].transpose(0, 1)
        tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt_og_input = tgt  # post-self-attention state, seeds the cdecoder stream at layer 0

        tgt = self.forward_ffn(
            self._deform_attn_block(
                self.cross_attn, self.dropout1, self.norm1,
                tgt, query_pos, reference_points, memory,
            )
        )

        if lid == 0:
            tgt_cdecoder = tgt_og_input
        tgt_cdecoder = self.forward_ffn(
            self._deform_attn_block(
                self.cross_attn, self.dropout1, self.norm1,
                tgt_cdecoder, query_pos, reference_points_cdecoder, memory,
            )
        )
        #*==================Entity================================

        #*==================Relation================================
        if attn_mask_tri is not None:
            attn_mask_csa, attn_mask_dea = attn_mask_tri
        else:
            attn_mask_csa = attn_mask_dea = None

        # so_embed defaults to None; treat that like a missing positional
        # embedding instead of failing on so_embed[0].
        sub_embed = None if so_embed is None else so_embed[0]
        obj_embed = None if so_embed is None else so_embed[1]

        #? CSA
        q_sub = self.with_pos_embed(self.with_pos_embed(tgt_sub, query_pos_sub), sub_embed)
        q_obj = self.with_pos_embed(self.with_pos_embed(tgt_obj, query_pos_obj), obj_embed)
        qk_so = torch.cat((q_sub, q_obj), dim=1)
        # Split sizes come from the concatenated tensors themselves so the
        # split is the exact inverse of the cat, even when query_pos_sub is
        # None or the subject/object counts differ.
        split_sizes = [tgt_sub.shape[1], tgt_obj.shape[1]]
        tgt_so = torch.cat((tgt_sub, tgt_obj), dim=1)
        tgt2_so = self.self_attn_so(
            qk_so.transpose(0, 1),
            qk_so.transpose(0, 1),
            tgt_so.transpose(0, 1),
            attn_mask=attn_mask_csa,
        )[0].transpose(0, 1)
        tgt_so = self.norm2_so(tgt_so + self.dropout2_so(tgt2_so))
        tgt_sub, tgt_obj = torch.split(tgt_so, split_sizes, dim=1)

        tgt_og_input_sub = tgt_sub
        tgt_og_input_obj = tgt_obj

        # The call order below (main DVA, cdecoder DVA, main DEA, cdecoder
        # DEA; subject before object) is kept as in the original so dropout
        # draws random numbers in the same sequence.

        #? SUB DVA
        tgt_sub = self._deform_attn_block(
            self.cross_attn_sub, self.dropout1_sub, self.norm1_sub,
            tgt_sub, query_pos_sub, reference_points_input_sub, memory,
        )

        #? SUB DVA one2many
        if lid == 0:
            tgt_sub_cdecoder = tgt_og_input_sub
        tgt_sub_cdecoder = self._deform_attn_block(
            self.cross_attn_sub, self.dropout1_sub, self.norm1_sub,
            tgt_sub_cdecoder, query_pos_sub, reference_points_input_sub_cdecoder, memory,
        )

        #* SUB DEA
        tgt_sub = self._entity_attn_block(
            self.cross_sub_entity, self.dropout2_sub, self.norm2_sub, self.forward_ffn_sub,
            tgt_sub, query_pos_sub, tgt, attn_mask_dea,
        )

        #* SUB DEA one2many
        tgt_sub_cdecoder = self._entity_attn_block(
            self.cross_sub_entity, self.dropout2_sub, self.norm2_sub, self.forward_ffn_sub,
            tgt_sub_cdecoder, query_pos_sub, tgt_cdecoder, attn_mask_dea,
        )

        #? OBJ DVA
        tgt_obj = self._deform_attn_block(
            self.cross_attn_obj, self.dropout1_obj, self.norm1_obj,
            tgt_obj, query_pos_obj, reference_points_input_obj, memory,
        )

        #? OBJ DVA one2many
        if lid == 0:
            tgt_obj_cdecoder = tgt_og_input_obj
        tgt_obj_cdecoder = self._deform_attn_block(
            self.cross_attn_obj, self.dropout1_obj, self.norm1_obj,
            tgt_obj_cdecoder, query_pos_obj, reference_points_input_obj_cdecoder, memory,
        )

        #* OBJ DEA
        tgt_obj = self._entity_attn_block(
            self.cross_obj_entity, self.dropout2_obj, self.norm2_obj, self.forward_ffn_obj,
            tgt_obj, query_pos_obj, tgt, attn_mask_dea,
        )

        #* OBJ DEA one2many
        tgt_obj_cdecoder = self._entity_attn_block(
            self.cross_obj_entity, self.dropout2_obj, self.norm2_obj, self.forward_ffn_obj,
            tgt_obj_cdecoder, query_pos_obj, tgt_cdecoder, attn_mask_dea,
        )

        return tgt, tgt_cdecoder, tgt_sub, tgt_obj, tgt_sub_cdecoder, tgt_obj_cdecoder